R3 Q8: Heterogeneity Analysis - Main Paper Method¶
Replicating Main Paper's Heterogeneity Approach¶
This notebook replicates the heterogeneity analysis method used in the main manuscript (line 238, trajectory_and_prs_cluster.R):
Method:
- Cluster by time-averaged signature loadings (k-means on mean theta across time)
- Project deviations from reference over time for each cluster
- Correlate with PRS to understand genetic differences between clusters
Key Difference from Deviation-Based Method:
- This method (main paper): Clusters patients based on their average signature levels, then visualizes how each cluster deviates from the population reference over time. More interpretable for clinical applications.
- Deviation-based method (R3_Q8_Heterogeneity_Continued): Clusters patients based on how their trajectories deviate from the population average. Better for pathway discovery.
Both approaches demonstrate heterogeneity, but serve different purposes.
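The two-step procedure above can be sketched on synthetic data (toy shapes, not the real dataset; variable names are illustrative):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
N, K, T = 60, 4, 10                       # patients x signatures x timepoints (toy sizes)
thetas = rng.random((N, K, T))            # per-patient signature trajectories
sig_refs = thetas.mean(axis=0)            # population reference, K x T

# Step 1: k-means on time-averaged signature loadings
mean_theta = thetas.mean(axis=2)          # N x K
clusters = KMeans(n_clusters=3, random_state=0, n_init=10).fit_predict(mean_theta)

# Step 2: per-cluster mean trajectory minus the population reference
deviation = np.stack([thetas[clusters == c].mean(axis=0) - sig_refs
                      for c in range(3)])
print(deviation.shape)                    # (n_clusters, K, T)
```

The `deviation` array here corresponds to `time_diff_by_cluster` computed (one timepoint at a time) in the analysis function below.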
Note¶
This notebook replicates the exact method from the main paper (trajectory_and_prs_cluster.R, lines 70-364). For a complete implementation, see the helper script that will be created; this notebook provides the framework and key visualizations.
# ============================================================================
# SETUP: Import and Configure
# ============================================================================
import sys
import os
%load_ext autoreload
sys.path.append('/Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/')
from helper_py.pathway_discovery import load_full_data
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from scipy.stats import ttest_ind
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 300
print("="*80)
print("HETEROGENEITY ANALYSIS: MAIN PAPER METHOD")
print("="*80)
print("\nThis notebook replicates the method from trajectory_and_prs_cluster.R")
print("Method: Cluster by time-averaged signature loadings, then visualize")
print(" deviations from reference over time and correlate with PRS")
print("="*80)
================================================================================
HETEROGENEITY ANALYSIS: MAIN PAPER METHOD
================================================================================
This notebook replicates the method from trajectory_and_prs_cluster.R
Method: Cluster by time-averaged signature loadings, then visualize
deviations from reference over time and correlate with PRS
================================================================================
Load Data¶
Load thetas, Y matrix, disease names, processed IDs, PRS, and reference trajectories.
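The per-patient loop used below to match PRS rows to patient IDs can also be done with a vectorized pandas `reindex`; a minimal sketch with toy IDs (column names mirror the notebook's `PatientID`):

```python
import numpy as np
import pandas as pd

# Toy PRS table keyed by patient ID
prs_df = pd.DataFrame({'PatientID': [101, 102, 103],
                       'PRS_A': [0.1, 0.2, 0.3],
                       'PRS_B': [1.0, 2.0, 3.0]})
processed_ids = np.array([103, 101, 999])   # 999 has no PRS row

# Align PRS rows to the patient order in one shot instead of looping
aligned = prs_df.set_index('PatientID').reindex(processed_ids)
prs_mask = ~aligned['PRS_A'].isna().to_numpy()  # which patients have PRS
prs_matrix = aligned.fillna(0.0).to_numpy()     # missing rows filled with zeros
print(prs_mask)                                 # [ True  True False]
```

Rows with no matching ID come back as NaN and are zero-filled, matching the loop's behavior of leaving unmatched rows at zero.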
# ============================================================================
# LOAD DATA
# ============================================================================
# Load full data (Y, thetas, disease_names, processed_ids)
Y, thetas, disease_names, processed_ids = load_full_data()
# Calculate time-averaged thetas (mean across time dimension)
# This is what we cluster on in the main paper method
time_averaged_theta = thetas.mean(axis=2)
print(f"\n✅ Time-averaged theta shape: {time_averaged_theta.shape} (patients × signatures)")
# Load PRS data
prs_path = Path('/Users/sarahurbut/aladynoulli2/pyScripts/csv/prs_with_eid.csv')
print(f"\nLoading PRS data from: {prs_path}")
prs_df = pd.read_csv(prs_path)
prs_cols = [col for col in prs_df.columns if col != 'PatientID']
print(f"✅ PRS shape: {prs_df.shape}")
print(f"✅ PRS columns: {len(prs_cols)} PRS scores")
# Load reference trajectories (sig_refs)
sig_refs_path = Path('/Users/sarahurbut/aladynoulli2/pyScripts/csv/reference_thetas.csv')
print(f"\nLoading reference trajectories from: {sig_refs_path}")
sig_refs = pd.read_csv(sig_refs_path, header=0)
sig_refs = sig_refs.values # Convert to numpy array (K × T)
print(f"✅ Reference trajectories shape: {sig_refs.shape} (signatures × timepoints)")
print("\n" + "="*80)
print("✅ DATA LOADING COMPLETE")
print("="*80)
Loading full dataset...
Loaded Y (full): torch.Size([407878, 348, 52])
Loaded thetas: (400000, 21, 52)
Loaded 400000 processed IDs
Subset Y to first 400K patients: torch.Size([400000, 348, 52])
Loaded 348 diseases
Total patients with complete data: 400000
✅ Time-averaged theta shape: (400000, 21) (patients × signatures)
Loading PRS data from: /Users/sarahurbut/aladynoulli2/pyScripts/csv/prs_with_eid.csv
✅ PRS shape: (400000, 37)
✅ PRS columns: 36 PRS scores
Loading reference trajectories from: /Users/sarahurbut/aladynoulli2/pyScripts/csv/reference_thetas.csv
✅ Reference trajectories shape: (21, 52) (signatures × timepoints)
================================================================================
✅ DATA LOADING COMPLETE
================================================================================
Analysis: Main Paper Method¶
This replicates traj_func from trajectory_and_prs_cluster.R:
- Identify diseased patients for target disease
- Cluster by time-averaged signature loadings (k-means, k=3)
- Calculate deviations from reference over time for each cluster
- Correlate with PRS (mean PRS per cluster, Cohen's d)
Diseases analyzed:
- Myocardial infarction
- Malignant neoplasm of female breast
- Major depressive disorder
This will create all visualizations for all three diseases.
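The cluster-vs-rest Cohen's d used later follows the standard pooled-standard-deviation form; the notebook computes it inline, but as a standalone sketch (the function name is illustrative):

```python
import numpy as np

def cohens_d(x_in, x_out):
    """Cohen's d for one cluster vs. the rest, using a pooled standard deviation."""
    n_in, n_out = len(x_in), len(x_out)
    pooled_sd = np.sqrt(((n_in - 1) * x_in.std() ** 2 + (n_out - 1) * x_out.std() ** 2)
                        / (n_in + n_out - 2))
    return (x_in.mean() - x_out.mean()) / pooled_sd if pooled_sd > 0 else 0.0

# Two unit-variance groups whose means differ by 1 should give d near 1
rng = np.random.default_rng(1)
a = rng.normal(1.0, 1.0, 5000)
b = rng.normal(0.0, 1.0, 5000)
print(round(cohens_d(a, b), 2))
```

As in the notebook, `.std()` uses NumPy's default population estimator (ddof=0); with the sample sizes involved here the difference from the sample estimator is negligible.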
# ============================================================================
# ANALYSIS FUNCTION: Run for a single disease
# ============================================================================
def analyze_disease_heterogeneity(target_disease, Y, thetas, time_averaged_theta, disease_names,
processed_ids, prs_df, prs_cols, sig_refs, n_clusters=3, random_state=42):
"""
Analyze heterogeneity for a single disease using main paper method.
Returns a dictionary with all results.
"""
print("\n" + "="*80)
print(f"ANALYZING: {target_disease}")
print("="*80)
# Find disease index
try:
disease_ix = disease_names.index(target_disease)
print(f"✅ Found '{target_disease}' at index {disease_ix}")
except ValueError:
print(f"❌ Disease '{target_disease}' not found")
print(f"Available diseases (first 10): {disease_names[:10]}")
return None
# Identify diseased patients (any timepoint)
diseased = np.where(Y[:, disease_ix, :].sum(axis=1) > 0)[0]
print(f"✅ Diseased patients: {len(diseased):,}")
if len(diseased) < n_clusters:
print(f"❌ Not enough patients for {n_clusters} clusters")
return None
# Get time-averaged theta for diseased patients
time_averaged_theta_diseased = time_averaged_theta[diseased, :]
print(f"✅ Time-averaged theta (diseased): {time_averaged_theta_diseased.shape}")
# Cluster by time-averaged signature loadings (k-means, k=3)
print(f"\n🔍 Clustering by time-averaged signature loadings (k={n_clusters})...")
kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
clusters = kmeans.fit_predict(time_averaged_theta_diseased)
print(f"✅ Cluster sizes: {np.bincount(clusters, minlength=n_clusters)}")
# Calculate deviations from reference over time for each cluster
K = thetas.shape[1] # Number of signatures
T = thetas.shape[2] # Number of timepoints
time_diff_by_cluster = np.zeros((n_clusters, K, T))
time_means_by_cluster = np.zeros((n_clusters, K, T))
print(f"\n📊 Calculating deviations from reference over time...")
for t in range(T):
# Get theta for this timepoint for diseased patients
time_spec_theta = thetas[diseased, :, t] # N_d × K
# Calculate mean theta per cluster for this timepoint
for c in range(n_clusters):
cluster_mask = clusters == c
if cluster_mask.sum() > 0:
time_means_by_cluster[c, :, t] = time_spec_theta[cluster_mask, :].mean(axis=0)
# Deviation from reference
time_diff_by_cluster[c, :, t] = time_means_by_cluster[c, :, t] - sig_refs[:, t]
print(f"✅ Calculated deviations for {n_clusters} clusters × {K} signatures × {T} timepoints")
# Load PRS for diseased patients
print(f"\n🧬 Loading PRS data for diseased patients...")
diseased_eids = processed_ids[diseased]
# Match PRS to diseased patients by eid
eid_to_prs_idx = {eid: idx for idx, eid in enumerate(prs_df['PatientID'].values)}
prs_matrix = np.zeros((len(diseased), len(prs_cols)))
prs_mask = np.zeros(len(diseased), dtype=bool)
for i, eid in enumerate(diseased_eids):
if eid in eid_to_prs_idx:
prs_idx = eid_to_prs_idx[eid]
prs_matrix[i, :] = prs_df.iloc[prs_idx][prs_cols].values
prs_mask[i] = True
print(f"✅ PRS data available for {prs_mask.sum():,} / {len(diseased):,} patients")
print(f"✅ PRS matrix shape: {prs_matrix.shape}")
# Calculate PRS means by cluster
print(f"\n📈 Calculating PRS means by cluster...")
prs_means_by_cluster = {}
for c in range(n_clusters):
cluster_mask = clusters == c
cluster_with_prs = cluster_mask & prs_mask
if cluster_with_prs.sum() > 0:
prs_means_by_cluster[c] = prs_matrix[cluster_with_prs, :].mean(axis=0)
print(f" Cluster {c+1}: {cluster_with_prs.sum():,} patients with PRS")
else:
prs_means_by_cluster[c] = np.zeros(len(prs_cols))
return {
'disease_name': target_disease,
'disease_ix': disease_ix,
'n_diseased': len(diseased),
'diseased': diseased,
'clusters': clusters,
'time_diff_by_cluster': time_diff_by_cluster,
'time_means_by_cluster': time_means_by_cluster,
'time_averaged_theta_diseased': time_averaged_theta_diseased,
'prs_matrix': prs_matrix,
'prs_mask': prs_mask,
'prs_means_by_cluster': prs_means_by_cluster,
'K': K,
'T': T,
'n_clusters': n_clusters
}
# ============================================================================
# RUN ANALYSIS FOR BOTH DISEASES
# ============================================================================
target_diseases = ["Myocardial infarction", "Malignant neoplasm of female breast", "Major depressive disorder"]
results_dict = {}
for target_disease in target_diseases:
results = analyze_disease_heterogeneity(
target_disease=target_disease,
Y=Y,
thetas=thetas,
time_averaged_theta=time_averaged_theta,
disease_names=disease_names,
processed_ids=processed_ids,
prs_df=prs_df,
prs_cols=prs_cols,
sig_refs=sig_refs,
n_clusters=3,
random_state=42
)
if results:
results_dict[target_disease] = results
print("\n" + "="*80)
print("✅ ALL ANALYSES COMPLETE")
print("="*80)
================================================================================
ANALYZING: Myocardial infarction
================================================================================
✅ Found 'Myocardial infarction' at index 112
✅ Diseased patients: 24,920
✅ Time-averaged theta (diseased): (24920, 21)
🔍 Clustering by time-averaged signature loadings (k=3)...
✅ Cluster sizes: [ 0 6775 7655 10490]
📊 Calculating deviations from reference over time...
✅ Calculated deviations for 3 clusters × 21 signatures × 52 timepoints
🧬 Loading PRS data for diseased patients...
✅ PRS data available for 24,920 / 24,920 patients
✅ PRS matrix shape: (24920, 36)
📈 Calculating PRS means by cluster...
Cluster 1: 6,775 patients with PRS
Cluster 2: 7,655 patients with PRS
Cluster 3: 10,490 patients with PRS
================================================================================
ANALYZING: Malignant neoplasm of female breast
================================================================================
✅ Found 'Malignant neoplasm of female breast' at index 17
✅ Diseased patients: 17,302
✅ Time-averaged theta (diseased): (17302, 21)
🔍 Clustering by time-averaged signature loadings (k=3)...
✅ Cluster sizes: [ 0 1419 5716 10167]
📊 Calculating deviations from reference over time...
✅ Calculated deviations for 3 clusters × 21 signatures × 52 timepoints
🧬 Loading PRS data for diseased patients...
✅ PRS data available for 17,302 / 17,302 patients
✅ PRS matrix shape: (17302, 36)
📈 Calculating PRS means by cluster...
Cluster 1: 1,419 patients with PRS
Cluster 2: 5,716 patients with PRS
Cluster 3: 10,167 patients with PRS
================================================================================
ANALYZING: Major depressive disorder
================================================================================
✅ Found 'Major depressive disorder' at index 66
✅ Diseased patients: 30,549
✅ Time-averaged theta (diseased): (30549, 21)
🔍 Clustering by time-averaged signature loadings (k=3)...
✅ Cluster sizes: [ 0 8599 12109 9841]
📊 Calculating deviations from reference over time...
✅ Calculated deviations for 3 clusters × 21 signatures × 52 timepoints
🧬 Loading PRS data for diseased patients...
✅ PRS data available for 30,549 / 30,549 patients
✅ PRS matrix shape: (30549, 36)
📈 Calculating PRS means by cluster...
Cluster 1: 8,599 patients with PRS
Cluster 2: 12,109 patients with PRS
Cluster 3: 9,841 patients with PRS
================================================================================
✅ ALL ANALYSES COMPLETE
================================================================================
Visualizations¶
Create the plots from the main paper method:
- Stacked area plot of deviations from reference over time
- Signature Cohen's d heatmap
- PRS heatmap by cluster
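Signed deviations cannot be fed to a stacked area plot directly; the "separated" stacked variant later in the notebook splits each trajectory into a part stacked above zero and a part stacked below. A minimal sketch of that split:

```python
import numpy as np

# Split a signed deviation trajectory into non-overlapping stacks
deviation = np.array([0.2, -0.1, 0.3, -0.4])
positive = np.maximum(deviation, 0)   # stacked above the zero line
negative = np.minimum(deviation, 0)   # sign-inverted before stacking below
assert np.allclose(positive + negative, deviation)
print(positive, negative)
```

Because `positive + negative` recovers the original trajectory exactly, the two stacks together show the same information as the signed curve.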
# ============================================================================
# PLOT 1: Stacked Area Plot of Deviations from Reference (WITH CONSISTENT COLORS)
# ============================================================================
output_dir = 'heterogeneity_main_paper_output'
os.makedirs(output_dir, exist_ok=True)
# Color palette matching Figure3_Individual_Trajectories.ipynb
# Uses seaborn's tab20 palette, with tab20b for signature 20 to ensure all 21 signatures have unique colors
# Note: Signature 5 (cardiovascular) swapped to red for biological interpretability
def get_signature_colors(K):
"""Get color palette matching Figure3 (seaborn tab20 palette with tab20b for 21st signature).
Signature 5 (cardiovascular) is assigned red color for biological interpretability.
"""
if K <= 20:
# Use tab20 for up to 20 signatures
colors = sns.color_palette("tab20", K)
else:
# Use tab20 for first 20, then tab20b for signature 20 (21st signature)
colors_20 = sns.color_palette("tab20", 20)
colors_b = sns.color_palette("tab20b", 20)
colors = list(colors_20) + [colors_b[0]] # First color from tab20b for signature 20
if K > 21:
# If somehow K > 21, use additional colors from tab20b
colors.extend(colors_b[1:K-20])
# Swap signature 5 and 6 colors: Sig 5 (cardiovascular) gets red, Sig 6 gets light green
if K > 6:
colors[5], colors[6] = colors[6], colors[5]
# Convert to numpy array for consistency with matplotlib stackplot
return np.array(colors)
for target_disease, results in results_dict.items():
print(f"\n{'='*80}")
print(f"Creating stacked area plot for: {target_disease}")
print(f"{'='*80}")
time_diff_by_cluster = results['time_diff_by_cluster']
clusters = results['clusters']
K = results['K']
T = results['T']
n_clusters = results['n_clusters']
# Get consistent color palette
sig_colors = get_signature_colors(K)
# Reshape time_diff_by_cluster for plotting
plot_data = []
ages = np.arange(30, 30 + T) # Age = timepoint + 30
for c in range(n_clusters):
for k in range(K):
for t in range(T):
plot_data.append({
'Cluster': c + 1,
'Signature': k,
'Age': ages[t],
'Deviation': time_diff_by_cluster[c, k, t]
})
plot_df = pd.DataFrame(plot_data)
# Create stacked area plot (one subplot per cluster)
fig, axes = plt.subplots(n_clusters, 1, figsize=(14, 5 * n_clusters), sharex=True)
for c in range(n_clusters):
cluster_data = plot_df[plot_df['Cluster'] == c + 1]
# Pivot for stacked area plot
pivot_data = cluster_data.pivot_table(
index='Age',
columns='Signature',
values='Deviation',
aggfunc='mean'
)
# Stacked area plot with consistent colors
axes[c].stackplot(
pivot_data.index,
*[pivot_data[col] for col in pivot_data.columns],
labels=[f'Sig {col}' for col in pivot_data.columns],
alpha=0.7,
colors=[sig_colors[int(col)] for col in pivot_data.columns]
)
cluster_size = (clusters == c).sum()
axes[c].set_title(f'Cluster {c + 1} (n={cluster_size:,})', fontsize=14, fontweight='bold')
axes[c].set_ylabel('Deviation from Reference', fontsize=12)
axes[c].axhline(y=0, color='black', linestyle='--', linewidth=0.5)
axes[c].grid(True, alpha=0.3)
axes[c].set_xlim(ages[0], ages[-1])
# Add legend
if K <= 10:
axes[c].legend(loc='upper left', fontsize=8, ncol=2)
else:
axes[c].legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=7, ncol=1)
axes[-1].set_xlabel('Age', fontsize=12)
fig.suptitle(
f'Signature Deviations from Reference: {target_disease}',
fontsize=16,
fontweight='bold'
)
plt.tight_layout()
# Save plot
output_path = os.path.join(output_dir, f'stacked_deviations_{target_disease.replace(" ", "_")}.pdf')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✅ Saved: {output_path}")
plt.show()
================================================================================
Creating stacked area plot for: Myocardial infarction
================================================================================
✅ Saved: heterogeneity_main_paper_output/stacked_deviations_Myocardial_infarction.pdf
================================================================================
Creating stacked area plot for: Malignant neoplasm of female breast
================================================================================
✅ Saved: heterogeneity_main_paper_output/stacked_deviations_Malignant_neoplasm_of_female_breast.pdf
================================================================================
Creating stacked area plot for: Major depressive disorder
================================================================================
✅ Saved: heterogeneity_main_paper_output/stacked_deviations_Major_depressive_disorder.pdf
# ============================================================================
# PLOT: LINE PLOT - Clean version with proper colors
# ============================================================================
# Same color function as Cell 9 - matching Figure3 palette
# Note: This function will be overridden by the one in Cell 9 if both cells are run
# Keeping for consistency if this cell is run independently
def get_signature_colors(K):
"""Get color palette matching Figure3 (seaborn tab20 palette with tab20b for 21st signature).
Signature 5 (cardiovascular) is assigned red color for biological interpretability.
"""
if K <= 20:
# Use tab20 for up to 20 signatures
colors = sns.color_palette("tab20", K)
else:
# Use tab20 for first 20, then tab20b for signature 20 (21st signature)
colors_20 = sns.color_palette("tab20", 20)
colors_b = sns.color_palette("tab20b", 20)
colors = list(colors_20) + [colors_b[0]] # First color from tab20b for signature 20
if K > 21:
# If somehow K > 21, use additional colors from tab20b
colors.extend(colors_b[1:K-20])
# Swap signature 5 and 6 colors: Sig 5 (cardiovascular) gets red, Sig 6 gets light green
if K > 6:
colors[5], colors[6] = colors[6], colors[5]
# Convert to numpy array for consistency with matplotlib
return np.array(colors)
for target_disease, results in results_dict.items():
print(f"\n{'='*80}")
print(f"Creating line plot for: {target_disease}")
print(f"{'='*80}")
time_diff_by_cluster = results['time_diff_by_cluster']
clusters = results['clusters']
K = results['K']
T = results['T']
n_clusters = results['n_clusters']
sig_colors = get_signature_colors(K)
ages = np.arange(30, 30 + T)
# Key signatures to emphasize with thicker lines
key_signatures = {5, 6, 8, 9}
fig, axes = plt.subplots(n_clusters, 1, figsize=(14, 5 * n_clusters), sharex=True)
for c in range(n_clusters):
cluster_size = (clusters == c).sum()
# Plot all signatures
for k in range(K):
deviation_trajectory = time_diff_by_cluster[c, k, :]
# Thicker lines for key signatures
linewidth = 3.0 if k in key_signatures else 1.5
alpha = 0.9 if k in key_signatures else 0.6
axes[c].plot(
ages,
deviation_trajectory,
label=f'Sig {k}',
alpha=alpha,
linewidth=linewidth,
color=sig_colors[k],
zorder=10 if k in key_signatures else 1 # Key sigs on top
)
axes[c].set_title(f'Cluster {c + 1} (n={cluster_size:,})', fontsize=14, fontweight='bold')
axes[c].set_ylabel('Deviation from Reference', fontsize=12)
axes[c].axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[c].grid(True, alpha=0.3, linestyle='--')
axes[c].set_xlim(ages[0], ages[-1])
# No individual legends
axes[-1].set_xlabel('Age', fontsize=12)
# Single shared legend for all signatures
from matplotlib.patches import Patch
all_sig_handles = []
all_sig_labels = []
for k in range(K):
patch = Patch(facecolor=sig_colors[k], alpha=0.9, edgecolor='black', linewidth=0.5)
all_sig_handles.append(patch)
all_sig_labels.append(f'Sig {k}')
fig.legend(
all_sig_handles,
all_sig_labels,
loc='center right',
bbox_to_anchor=(0.98, 0.5),
fontsize=8,
ncol=1,
framealpha=0.9
)
fig.suptitle(
f'Signature Deviations from Reference: {target_disease}',
fontsize=16,
fontweight='bold'
)
plt.tight_layout(rect=[0, 0, 0.95, 1]) # Leave space on right for legend
# Save plot
output_path = os.path.join(output_dir, f'line_deviations_{target_disease.replace(" ", "_")}.pdf')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✅ Saved: {output_path}")
plt.show()
================================================================================
Creating line plot for: Myocardial infarction
================================================================================
✅ Saved: heterogeneity_main_paper_output/line_deviations_Myocardial_infarction.pdf
================================================================================
Creating line plot for: Malignant neoplasm of female breast
================================================================================
✅ Saved: heterogeneity_main_paper_output/line_deviations_Malignant_neoplasm_of_female_breast.pdf
================================================================================
Creating line plot for: Major depressive disorder
================================================================================
✅ Saved: heterogeneity_main_paper_output/line_deviations_Major_depressive_disorder.pdf
# ============================================================================
# PLOT: SEPARATED STACKED - Single shared legend
# ============================================================================
for target_disease, results in results_dict.items():
time_diff_by_cluster = results['time_diff_by_cluster']
clusters = results['clusters']
K = results['K']
T = results['T']
n_clusters = results['n_clusters']
sig_colors = get_signature_colors(K)
ages = np.arange(30, 30 + T)
fig, axes = plt.subplots(n_clusters, 1, figsize=(16, 5 * n_clusters), sharex=True)
# Collect all labels and handles for shared legend
all_handles = []
all_labels = []
for c in range(n_clusters):
cluster_size = (clusters == c).sum()
# Separate positive and negative, keeping track of signature indices
positive_stacks = []
negative_stacks = []
pos_sig_indices = []
neg_sig_indices = []
for k in range(K):
deviation = time_diff_by_cluster[c, k, :]
positive = np.maximum(deviation, 0)
negative = np.minimum(deviation, 0)
if np.any(positive > 0):
positive_stacks.append(positive)
pos_sig_indices.append(k)
if np.any(negative < 0):
negative_stacks.append(negative)
neg_sig_indices.append(k)
# Stack positive deviations ABOVE zero
if positive_stacks:
handles_pos = axes[c].stackplot(
ages,
*positive_stacks,
labels=[f'Sig {k}' for k in pos_sig_indices],
alpha=0.7,
colors=[sig_colors[k] for k in pos_sig_indices],
baseline='zero'
)
# Collect handles for legend (only on first cluster)
if c == 0:
all_handles.extend(handles_pos)
all_labels.extend([f'Sig {k}' for k in pos_sig_indices])
# Stack negative deviations BELOW zero
if negative_stacks:
negative_stacks_inverted = [-n for n in negative_stacks]
handles_neg = axes[c].stackplot(
ages,
*negative_stacks_inverted,
labels=[f'Sig {k}' for k in neg_sig_indices],
alpha=0.7,
colors=[sig_colors[k] for k in neg_sig_indices],
baseline='zero'
)
# Collect handles for legend (only on first cluster)
if c == 0:
all_handles.extend(handles_neg)
all_labels.extend([f'Sig {k}' for k in neg_sig_indices])
axes[c].axhline(y=0, color='black', linestyle='-', linewidth=2)
axes[c].set_title(f'Cluster {c + 1} (n={cluster_size:,})', fontsize=14, fontweight='bold')
axes[c].set_ylabel('Deviation from Reference', fontsize=12)
axes[c].grid(True, alpha=0.3)
axes[c].set_xlim(ages[0], ages[-1])
# No legend on individual subplots
axes[-1].set_xlabel('Age', fontsize=12)
# Add single shared legend (all 21 signatures)
# Create legend with all signatures, even if not all appear in first cluster
from matplotlib.patches import Patch
all_sig_handles = []
all_sig_labels = []
for k in range(K):
# Create a patch for each signature color
patch = Patch(facecolor=sig_colors[k], alpha=0.7, edgecolor='black', linewidth=0.5)
all_sig_handles.append(patch)
all_sig_labels.append(f'Sig {k}')
# Place legend on the right side of the figure
fig.legend(
all_sig_handles,
all_sig_labels,
loc='center right',
bbox_to_anchor=(0.98, 0.5),
fontsize=8,
ncol=1,
framealpha=0.9
)
fig.suptitle(f'Signature Deviations (Separated): {target_disease}', fontsize=16, fontweight='bold')
plt.tight_layout(rect=[0, 0, 0.95, 1]) # Leave space on right for legend
plt.show()
# ============================================================================
# PLOT 2: Signature Cohen's d Heatmap
# ============================================================================
# This replicates the signature Cohen's d heatmap from trajectory_and_prs_cluster.R (p_sig)
for target_disease, results in results_dict.items():
print(f"\n{'='*80}")
print(f"Creating signature Cohen's d heatmap for: {target_disease}")
print(f"{'='*80}")
clusters = results['clusters']
time_averaged_theta_diseased = results['time_averaged_theta_diseased']
K = results['K']
n_clusters = results['n_clusters']
# Calculate Cohen's d for signatures
signature_cohens_d = []
for c in range(n_clusters):
in_cluster = clusters == c
out_cluster = clusters != c
for k in range(K):
mean_in = time_averaged_theta_diseased[in_cluster, k].mean()
mean_out = time_averaged_theta_diseased[out_cluster, k].mean()
std_in = time_averaged_theta_diseased[in_cluster, k].std()
std_out = time_averaged_theta_diseased[out_cluster, k].std()
n_in = in_cluster.sum()
n_out = out_cluster.sum()
# Pooled standard deviation
pooled_sd = np.sqrt(((n_in - 1) * std_in**2 + (n_out - 1) * std_out**2) / (n_in + n_out - 2))
d = (mean_in - mean_out) / pooled_sd if pooled_sd > 0 else 0
# P-value
try:
_, pval = ttest_ind(
time_averaged_theta_diseased[in_cluster, k],
time_averaged_theta_diseased[out_cluster, k]
)
except Exception:
pval = np.nan
signature_cohens_d.append({
'Cluster': c + 1,
'Signature': k,
'Mean_In': mean_in,
'Mean_Out': mean_out,
'Cohen_d': d,
'p_value': pval
})
signature_cohens_d_df = pd.DataFrame(signature_cohens_d)
# Create heatmap
pivot_sig = signature_cohens_d_df.pivot_table(
index='Signature',
columns='Cluster',
values='Cohen_d',
aggfunc='mean'
)
fig, ax = plt.subplots(figsize=(6, 12))
sns.heatmap(
pivot_sig,
annot=True,
fmt='.2f',
cmap='RdBu_r',
center=0,
vmin=-3,
vmax=3,
cbar_kws={'label': "Cohen's d"},
ax=ax,
linewidths=0.5
)
ax.set_title(
f"Signature Cohen's d by Cluster: {target_disease}",
fontsize=14,
fontweight='bold'
)
ax.set_xlabel('Cluster', fontsize=12)
ax.set_ylabel('Signature', fontsize=12)
plt.tight_layout()
output_path = os.path.join(output_dir, f'signature_cohens_d_{target_disease.replace(" ", "_")}.pdf')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✅ Saved: {output_path}")
plt.show()
# Save results
csv_path = os.path.join(output_dir, f'signature_cohens_d_{target_disease.replace(" ", "_")}.csv')
signature_cohens_d_df.to_csv(csv_path, index=False)
print(f"✅ Saved CSV: {csv_path}")
================================================================================
Creating signature Cohen's d heatmap for: Myocardial infarction
================================================================================
✅ Saved: heterogeneity_main_paper_output/signature_cohens_d_Myocardial_infarction.pdf
✅ Saved CSV: heterogeneity_main_paper_output/signature_cohens_d_Myocardial_infarction.csv
================================================================================
Creating signature Cohen's d heatmap for: Malignant neoplasm of female breast
================================================================================
✅ Saved: heterogeneity_main_paper_output/signature_cohens_d_Malignant_neoplasm_of_female_breast.pdf
✅ Saved CSV: heterogeneity_main_paper_output/signature_cohens_d_Malignant_neoplasm_of_female_breast.csv
================================================================================
Creating signature Cohen's d heatmap for: Major depressive disorder
================================================================================
✅ Saved: heterogeneity_main_paper_output/signature_cohens_d_Major_depressive_disorder.pdf
✅ Saved CSV: heterogeneity_main_paper_output/signature_cohens_d_Major_depressive_disorder.csv
# ============================================================================
# PLOT 3: PRS Heatmap by Cluster
# ============================================================================
# This replicates the PRS heatmap from trajectory_and_prs_cluster.R (p2)
for target_disease, results in results_dict.items():
print(f"\n{'='*80}")
print(f"Creating PRS heatmap for: {target_disease}")
print(f"{'='*80}")
clusters = results['clusters']
prs_matrix = results['prs_matrix']
prs_mask = results['prs_mask']
n_clusters = results['n_clusters']
# Calculate Cohen's d for PRS and create mean PRS matrix
prs_cohens_d = []
prs_means_list = []
for c in range(n_clusters):
in_cluster = (clusters == c) & prs_mask
out_cluster = (clusters != c) & prs_mask
for r, prs_name in enumerate(prs_cols):
if in_cluster.sum() > 0 and out_cluster.sum() > 0:
mean_in = prs_matrix[in_cluster, r].mean()
mean_out = prs_matrix[out_cluster, r].mean()
std_in = prs_matrix[in_cluster, r].std()
std_out = prs_matrix[out_cluster, r].std()
n_in = in_cluster.sum()
n_out = out_cluster.sum()
pooled_sd = np.sqrt(((n_in - 1) * std_in**2 + (n_out - 1) * std_out**2) / (n_in + n_out - 2))
d = (mean_in - mean_out) / pooled_sd if pooled_sd > 0 else 0
prs_cohens_d.append({
'Cluster': c + 1,
'PRS': prs_name,
'Mean_In': mean_in,
'Mean_Out': mean_out,
'Cohen_d': d
})
prs_means_list.append({
'Cluster': c + 1,
'PRS': prs_name,
'Mean': mean_in
})
prs_cohens_d_df = pd.DataFrame(prs_cohens_d)
prs_means_df = pd.DataFrame(prs_means_list)
# Get top PRS by absolute Cohen's d
prs_summary = prs_cohens_d_df.groupby('PRS')['Cohen_d'].apply(lambda x: x.abs().max()).sort_values(ascending=False)
top_prs = prs_summary.head(20).index.tolist()
# Create heatmap for top PRS
prs_subset = prs_means_df[prs_means_df['PRS'].isin(top_prs)]
pivot_prs = prs_subset.pivot_table(
index='PRS',
columns='Cluster',
values='Mean',
aggfunc='mean'
)
fig, ax = plt.subplots(figsize=(6, 10))
sns.heatmap(
pivot_prs,
annot=True,
fmt='.2f',
cmap='RdBu_r',
center=0,
cbar_kws={'label': 'Mean PRS'},
ax=ax,
linewidths=0.5
)
ax.set_title(
f'PRS Means by Cluster: {target_disease}',
fontsize=14,
fontweight='bold'
)
ax.set_xlabel('Cluster', fontsize=12)
ax.set_ylabel('PRS', fontsize=12)
plt.tight_layout()
output_path = os.path.join(output_dir, f'prs_means_{target_disease.replace(" ", "_")}.pdf')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✅ Saved: {output_path}")
plt.show()
# Save results
csv_path = os.path.join(output_dir, f'prs_cohens_d_{target_disease.replace(" ", "_")}.csv')
prs_cohens_d_df.to_csv(csv_path, index=False)
print(f"✅ Saved CSV: {csv_path}")
================================================================================
Creating PRS heatmap for: Myocardial infarction
================================================================================
✅ Saved: heterogeneity_main_paper_output/prs_means_Myocardial_infarction.pdf
✅ Saved CSV: heterogeneity_main_paper_output/prs_cohens_d_Myocardial_infarction.csv

================================================================================
Creating PRS heatmap for: Malignant neoplasm of female breast
================================================================================
✅ Saved: heterogeneity_main_paper_output/prs_means_Malignant_neoplasm_of_female_breast.pdf
✅ Saved CSV: heterogeneity_main_paper_output/prs_cohens_d_Malignant_neoplasm_of_female_breast.csv

================================================================================
Creating PRS heatmap for: Major depressive disorder
================================================================================
✅ Saved: heterogeneity_main_paper_output/prs_means_Major_depressive_disorder.pdf
✅ Saved CSV: heterogeneity_main_paper_output/prs_cohens_d_Major_depressive_disorder.csv
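The heatmap above is built from per-cluster Cohen's d values and PRS means. The exact formula lives in an earlier cell not shown in this excerpt, so the following is a minimal sketch of the standard pooled-SD definition of Cohen's d on simulated in-cluster vs. out-of-cluster PRS values (`cohens_d` and the sample arrays are illustrative, not the notebook's actual variables):

```python
import numpy as np

def cohens_d(in_vals, out_vals):
    """Standardized mean difference with a pooled standard deviation."""
    n1, n2 = len(in_vals), len(out_vals)
    # Pool the two variances, weighting each by its degrees of freedom
    pooled_sd = np.sqrt(((n1 - 1) * np.var(in_vals, ddof=1)
                         + (n2 - 1) * np.var(out_vals, ddof=1)) / (n1 + n2 - 2))
    return (np.mean(in_vals) - np.mean(out_vals)) / pooled_sd

rng = np.random.default_rng(0)
in_cluster = rng.normal(0.5, 1.0, size=200)   # simulated PRS, patients in cluster
out_cluster = rng.normal(0.0, 1.0, size=800)  # simulated PRS, everyone else
d = cohens_d(in_cluster, out_cluster)
print(f"Cohen's d = {d:.2f}")  # close to the simulated mean shift of 0.5
```

A value near 0.5 indicates a medium effect, i.e. the cluster's mean PRS sits about half a standard deviation above the rest of the cohort.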
Summary¶
This notebook replicates the main paper's heterogeneity method (trajectory_and_prs_cluster.R) and produces the following visualizations:
- Clustering: Patients clustered by time-averaged signature loadings (k-means, k=3)
- Deviations: Stacked area plot showing how each cluster deviates from population reference over time
- Signature Analysis: Heatmap of Cohen's d for signatures by cluster
- PRS Correlation: Heatmap of mean PRS by cluster
Key Difference from Deviation-Based Method:
- This method (main paper): Clusters first on average loadings, then visualizes deviations (more interpretable for clinical stratification)
- Deviation-based method (R3_Q8_Heterogeneity_Continued): Clusters by deviations directly (better for pathway discovery)
Both approaches demonstrate heterogeneity, but serve different purposes.
Output files saved to: heterogeneity_main_paper_output/
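The two-step pipeline summarized above (k-means on time-averaged loadings, then per-cluster deviation from the population reference) can be sketched end-to-end on synthetic data. Array names such as `theta` and `reference` are placeholders for the notebook's actual tensors, and the Dirichlet draws stand in for softmax-derived signature loadings:

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(42)
N, K, T = 300, 5, 52  # patients, signatures, timepoints (toy sizes)
# Simplex-valued loadings per patient and timepoint, shaped (N, K, T)
theta = rng.dirichlet(np.ones(K), size=(N, T)).transpose(0, 2, 1)

# Step 1: cluster patients on time-averaged signature loadings
theta_mean = theta.mean(axis=2)  # (N, K)
km = KMeans(n_clusters=3, n_init=10, random_state=0)
clusters = km.fit_predict(theta_mean)

# Step 2: per-cluster mean deviation from the population reference over time
reference = theta.mean(axis=0)  # (K, T) population-average trajectory
time_diff_by_cluster = np.stack([
    theta[clusters == c].mean(axis=0) - reference  # (K, T) per cluster
    for c in range(3)
])
print(time_diff_by_cluster.shape)  # (3, 5, 52)
```

The resulting `(n_clusters, K, T)` array is exactly the shape the stacked-area and filled-line plots below iterate over.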
# ============================================================================
# PLOT: Average theta per cluster (normalized, not deviation from reference)
# ============================================================================
for target_disease, results in results_dict.items():
print(f"\n{'='*80}")
print(f"Creating average theta plot (normalized) for: {target_disease}")
print(f"{'='*80}")
time_means_by_cluster = results['time_means_by_cluster'] # Use actual theta, not deviation
clusters = results['clusters']
K = results['K']
T = results['T']
n_clusters = results['n_clusters']
sig_colors = get_signature_colors(K)
ages = np.arange(30, 30 + T)
fig, axes = plt.subplots(n_clusters, 1, figsize=(14, 5 * n_clusters), sharex=True)
for c in range(n_clusters):
cluster_size = (clusters == c).sum()
# Get mean theta for this cluster at each timepoint
cluster_theta = time_means_by_cluster[c, :, :] # (K, T)
# Normalize so signatures sum to 1 at each timepoint (if desired)
# Or just use raw theta values
# Option A: Normalize to sum to 1
cluster_theta_normalized = cluster_theta / (cluster_theta.sum(axis=0, keepdims=True) + 1e-10)
# Option B: Use raw theta (uncomment if you prefer)
# cluster_theta_normalized = cluster_theta
# Stacked area plot of normalized theta
axes[c].stackplot(
ages,
*[cluster_theta_normalized[k, :] for k in range(K)],
labels=[f'Sig {k}' for k in range(K)],
alpha=0.7,
colors=[sig_colors[k] for k in range(K)],
baseline='zero'
)
axes[c].set_title(f'Cluster {c + 1} (n={cluster_size:,})', fontsize=14, fontweight='bold')
axes[c].set_ylabel('Normalized Signature Loading', fontsize=12)
axes[c].set_ylim(0, 1) # Since normalized, should sum to 1
axes[c].grid(True, alpha=0.3)
axes[c].set_xlim(ages[0], ages[-1])
axes[-1].set_xlabel('Age', fontsize=12)
# Single shared legend
from matplotlib.patches import Patch
all_sig_handles = []
all_sig_labels = []
for k in range(K):
patch = Patch(facecolor=sig_colors[k], alpha=0.7, edgecolor='black', linewidth=0.5)
all_sig_handles.append(patch)
all_sig_labels.append(f'Sig {k}')
fig.legend(
all_sig_handles,
all_sig_labels,
loc='center right',
bbox_to_anchor=(0.98, 0.5),
fontsize=8,
ncol=1,
framealpha=0.9
)
fig.suptitle(
f'Normalized Signature Loadings by Cluster: {target_disease}',
fontsize=16,
fontweight='bold'
)
plt.tight_layout(rect=[0, 0, 0.95, 1])
output_path = os.path.join(output_dir, f'normalized_theta_{target_disease.replace(" ", "_")}.pdf')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✅ Saved: {output_path}")
plt.show()
================================================================================
Creating average theta plot (normalized) for: Myocardial infarction
================================================================================
✅ Saved: heterogeneity_main_paper_output/normalized_theta_Myocardial_infarction.pdf

================================================================================
Creating average theta plot (normalized) for: Malignant neoplasm of female breast
================================================================================
✅ Saved: heterogeneity_main_paper_output/normalized_theta_Malignant_neoplasm_of_female_breast.pdf

================================================================================
Creating average theta plot (normalized) for: Major depressive disorder
================================================================================
✅ Saved: heterogeneity_main_paper_output/normalized_theta_Major_depressive_disorder.pdf
# ============================================================================
# PLOT: Line plots with transparent filled areas (prettier than plain lines)
# ============================================================================
for target_disease, results in results_dict.items():
print(f"\n{'='*80}")
print(f"Creating line plot with filled areas for: {target_disease}")
print(f"{'='*80}")
time_diff_by_cluster = results['time_diff_by_cluster']
clusters = results['clusters']
K = results['K']
T = results['T']
n_clusters = results['n_clusters']
sig_colors = get_signature_colors(K)
ages = np.arange(30, 30 + T)
# Key signatures to emphasize
key_signatures = {5, 6, 8, 9}
fig, axes = plt.subplots(n_clusters, 1, figsize=(14, 5 * n_clusters), sharex=True)
for c in range(n_clusters):
cluster_size = (clusters == c).sum()
# Plot all signatures with filled areas
for k in range(K):
deviation_trajectory = time_diff_by_cluster[c, k, :]
# Thicker lines for key signatures
linewidth = 3.0 if k in key_signatures else 1.5
alpha_line = 0.9 if k in key_signatures else 0.6
alpha_fill = 0.3 if k in key_signatures else 0.15 # More transparent fill
# Fill area below (or above if negative)
axes[c].fill_between(
ages,
0,
deviation_trajectory,
alpha=alpha_fill,
color=sig_colors[k],
zorder=1 if k in key_signatures else 0
)
# Line on top
axes[c].plot(
ages,
deviation_trajectory,
label=f'Sig {k}',
alpha=alpha_line,
linewidth=linewidth,
color=sig_colors[k],
zorder=10 if k in key_signatures else 2
)
axes[c].set_title(f'Cluster {c + 1} (n={cluster_size:,})', fontsize=14, fontweight='bold')
axes[c].set_ylabel('Deviation from Reference', fontsize=12)
axes[c].axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[c].grid(True, alpha=0.3, linestyle='--')
axes[c].set_xlim(ages[0], ages[-1])
axes[-1].set_xlabel('Age', fontsize=12)
# Single shared legend
from matplotlib.patches import Patch
all_sig_handles = []
all_sig_labels = []
for k in range(K):
patch = Patch(facecolor=sig_colors[k], alpha=0.9, edgecolor='black', linewidth=0.5)
all_sig_handles.append(patch)
all_sig_labels.append(f'Sig {k}')
fig.legend(
all_sig_handles,
all_sig_labels,
loc='center right',
bbox_to_anchor=(0.98, 0.5),
fontsize=8,
ncol=1,
framealpha=0.9
)
fig.suptitle(
f'Signature Deviations from Reference: {target_disease}',
fontsize=16,
fontweight='bold'
)
plt.tight_layout(rect=[0, 0, 0.95, 1])
output_path = os.path.join(output_dir, f'line_filled_{target_disease.replace(" ", "_")}.pdf')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✅ Saved: {output_path}")
plt.show()
================================================================================
Creating line plot with filled areas for: Myocardial infarction
================================================================================
✅ Saved: /Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/Apps/Overleaf/Aladynoulli_Nature/paper_figs/fig3/line_filled_Myocardial_infarction.pdf

================================================================================
Creating line plot with filled areas for: Malignant neoplasm of female breast
================================================================================
✅ Saved: /Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/Apps/Overleaf/Aladynoulli_Nature/paper_figs/fig3/line_filled_Malignant_neoplasm_of_female_breast.pdf

================================================================================
Creating line plot with filled areas for: Major depressive disorder
================================================================================
✅ Saved: /Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/Apps/Overleaf/Aladynoulli_Nature/paper_figs/fig3/line_filled_Major_depressive_disorder.pdf
# Test if heterogeneity clustering is similar enough
%load_ext autoreload
%autoreload 2
import subprocess
import sys
script_path = '/Users/sarahurbut/aladynoulli2/claudefile/test_heterogeneity_clustering_similarity.py'
cmd = [sys.executable, script_path]
print("Testing clustering similarity...")
print(f"Command: {' '.join(cmd)}")
print("="*80)
result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
if result.stderr:
print("\nSTDERR:")
print(result.stderr)
if result.returncode == 0:
print("\n✓ Test completed successfully!")
else:
print(f"\n✗ Test failed with return code {result.returncode}")
Testing clustering similarity...
Command: /opt/miniconda3/envs/new_env_pyro2/bin/python /Users/sarahurbut/aladynoulli2/claudefile/test_heterogeneity_clustering_similarity.py
================================================================================
================================================================================
TESTING HETEROGENEITY CLUSTERING SIMILARITY
================================================================================
This script tests if cluster assignments are similar enough
that heterogeneity analysis doesn't need to be remade.
Loading Y matrix and disease names...
Y shape: (400000, 348, 52)
Diseases: 348
Loading original thetas from training run...
Original thetas shape: (400000, 21, 52)
Note: These are from original training run (what heterogeneity analysis uses)
Loading lambda from unregularized TRAINING batches...
(These are what fixedgk predictions are based on)
Found 40 model files
Loaded enrollment_model_VECTORIZED_W0.0001_nolr_batch_0_10000.pt: lambda shape (10000, 21, 52)
Loaded enrollment_model_VECTORIZED_W0.0001_nolr_batch_10000_20000.pt: lambda shape (10000, 21, 52)
Loaded enrollment_model_VECTORIZED_W0.0001_nolr_batch_20000_30000.pt: lambda shape (10000, 21, 52)
Loaded enrollment_model_VECTORIZED_W0.0001_nolr_batch_30000_40000.pt: lambda shape (10000, 21, 52)
Loaded enrollment_model_VECTORIZED_W0.0001_nolr_batch_40000_50000.pt: lambda shape (10000, 21, 52)
Combined lambda shape: (400000, 21, 52)
Converting lambda to theta (softmax)...
Theta shape: (400000, 21, 52)
================================================================================
COMPARING CLUSTER ASSIGNMENTS
================================================================================
Time-averaged theta shapes:
Original (regularized): (400000, 21)
Unregularized: (400000, 21)
================================================================================
CHECKING CORRELATION OF TIME-AVERAGED THETAS
================================================================================
Sampling 100,000 patients for correlation check
Overall correlation (all signatures × patients):
Pearson r = 1.000000 (p < 0.00e+00)
Per-signature correlations:
Signature 0: r = 0.999999
Signature 1: r = 1.000000
Signature 2: r = 1.000000
Signature 3: r = 1.000000
Signature 4: r = 1.000000
Signature 19: r = 0.999998
Signature 20: r = 1.000000
Mean per-signature correlation: 0.999999
Min per-signature correlation: 0.999992
Max per-signature correlation: 1.000000
============================================================
Disease: Myocardial infarction
Diseased patients: 24,920
Clustering original (regularized) thetas...
Clustering unregularized thetas...
Results:
Adjusted Rand Index: 0.9914
Cluster overlap (after label matching): 99.7% (24850/24920 patients)
Mean centroid distance (after remapping): 0.000176
Original (regularized) cluster sizes: [6775, 7655, 10490]
Unregularized cluster sizes (original labels): [6746, 10532, 7642]
Unregularized cluster sizes (remapped to match): [6746, 7642, 10532]
Label mapping: {0: 0, 2: 1, 1: 2}
============================================================
Disease: Malignant neoplasm of female breast
Diseased patients: 17,302
Clustering original (regularized) thetas...
Clustering unregularized thetas...
Results:
Adjusted Rand Index: 0.9998
Cluster overlap (after label matching): 100.0% (17301/17302 patients)
Mean centroid distance (after remapping): 0.000008
Original (regularized) cluster sizes: [1419, 5716, 10167]
Unregularized cluster sizes (original labels): [1419, 5715, 10168]
Unregularized cluster sizes (remapped to match): [1419, 5715, 10168]
Label mapping: {0: 0, 1: 1, 2: 2}
============================================================
Disease: Major depressive disorder
Diseased patients: 30,549
Clustering original (regularized) thetas...
Clustering unregularized thetas...
Results:
Adjusted Rand Index: 0.9995
Cluster overlap (after label matching): 100.0% (30544/30549 patients)
Mean centroid distance (after remapping): 0.000009
Original (regularized) cluster sizes: [8599, 12109, 9841]
Unregularized cluster sizes (original labels): [8597, 12108, 9844]
Unregularized cluster sizes (remapped to match): [8597, 12108, 9844]
Label mapping: {0: 0, 1: 1, 2: 2}
================================================================================
SUMMARY
================================================================================
Cluster Similarity Metrics:
disease n_patients ari overlap_pct mean_centroid_distance
Myocardial infarction 24920 0.991375 99.719101 0.000176
Malignant neoplasm of female breast 17302 0.999786 99.994220 0.000008
Major depressive disorder 30549 0.999521 99.983633 0.000009
Mean Adjusted Rand Index: 0.9969
Mean Cluster Overlap: 99.9%
Mean Centroid Distance: 0.000064
================================================================================
INTERPRETATION
================================================================================
⚠️ IMPORTANT NOTE:
This compares:
- Original: Thetas from REGULARIZED training batches
- Unregularized: Thetas from UNREGULARIZED training batches
Both are from TRAINING (same context). Fixedgk predictions use pooled gamma/kappa
from unregularized training, so these thetas represent what fixedgk predictions
would produce. The key question is: are clusters similar enough that heterogeneity
patterns hold?
Adjusted Rand Index (ARI):
- ARI = 1.0: Perfect agreement
- ARI > 0.9: Very similar clusters
- ARI > 0.7: Similar clusters
- ARI < 0.5: Different clusters
Cluster Overlap (after label matching):
- 99.9% of patients in matching clusters
- This is more interpretable than ARI when clusters are imbalanced
✅ High ARI (mean 0.9969) + high overlap (99.9%) + low centroid distance (0.000064)
→ Clusters are VERY SIMILAR
→ Heterogeneity analysis likely doesn't need to be remade
✓ Saved results to: /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/results/clustering_similarity_test.csv
================================================================================
✓ Test completed successfully!
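The ARI, label matching, and overlap reported above can be reproduced with scikit-learn's `adjusted_rand_score` plus a Hungarian assignment on the confusion matrix. This is a sketch of the agreement metrics, not the exact test script; `cluster_agreement` is a hypothetical helper:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import adjusted_rand_score, confusion_matrix

def cluster_agreement(labels_a, labels_b):
    """ARI plus best-case overlap after optimally remapping labels_b onto labels_a."""
    ari = adjusted_rand_score(labels_a, labels_b)
    cm = confusion_matrix(labels_a, labels_b)
    # Hungarian algorithm on the negated counts finds the label permutation
    # that maximizes the number of agreeing patients
    rows, cols = linear_sum_assignment(-cm)
    overlap = cm[rows, cols].sum() / len(labels_a)
    mapping = dict(zip(cols, rows))  # label in b -> matched label in a
    return ari, overlap, mapping

a = np.array([0, 0, 1, 1, 2, 2, 2, 0])
b = np.array([1, 1, 2, 2, 0, 0, 0, 1])  # same partition, permuted labels
ari, overlap, mapping = cluster_agreement(a, b)
print(ari, overlap, mapping)  # 1.0 1.0 {1: 0, 2: 1, 0: 2}
```

Because ARI is invariant to label permutation, identical partitions score 1.0 even when the raw labels differ; the overlap statistic makes the same point in a more interpretable way when cluster sizes are imbalanced.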